import re

import pytest
from tests.utils.model import new_model
from gpustack.scheduler.scheduler import evaluate_pretrained_config
from gpustack.schemas.models import CategoryEnum, BackendEnum


@pytest.mark.parametrize(
    "case_name, model, expect_error, expect_error_match, expect_categories",
    [
        (
            # Checkpoint:
            # The model contains custom code but `--trust-remote-code` is not provided.
            # This should raise a ValueError with a specific message.
            "custom_code_without_trust_remote_code",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="microsoft/Phi-4-multimodal-instruct",
                backend=BackendEnum.VLLM,
                backend_parameters=[],
            ),
            ValueError,
            "The model contains custom code that must be executed to load correctly. If you trust the source, please pass the backend parameter `--trust-remote-code` to allow custom code to be run.",
            None,
        ),
        (
            # Checkpoint:
            # The model contains custom code and `--trust-remote-code` is provided.
            # This should pass without errors and set the model category to LLM.
            "custom_code_with_trust_remote_code",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="microsoft/Phi-4-multimodal-instruct",
                backend=BackendEnum.VLLM,
                backend_parameters=["--trust-remote-code"],
            ),
            None,
            None,
            ["LLM"],
        ),
        (
            # Checkpoint:
            # The model has an unsupported architecture.
            # This should raise a ValueError with a specific message.
            "unsupported_architecture",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="google-t5/t5-base",
                backend=BackendEnum.VLLM,
                backend_parameters=[],
            ),
            ValueError,
            "Unsupported architecture:",
            None,
        ),
        (
            # Checkpoint:
            # The model has an unsupported architecture but uses the custom backend.
            # This should pass without errors.
            "pass_unsupported_architecture_custom_backend",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="google-t5/t5-base",
                backend=BackendEnum.CUSTOM,
                backend_parameters=[],
            ),
            None,
            None,
            None,
        ),
        (
            # Checkpoint:
            # The model has an unsupported architecture but uses a custom backend version.
            # This should pass without errors.
            "pass_unsupported_architecture_custom_backend_version",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="google-t5/t5-base",
                backend=BackendEnum.VLLM,
                backend_version="custom_version",
                backend_parameters=[],
            ),
            None,
            None,
            None,
        ),
        (
            # Checkpoint:
            # The model has a supported architecture.
            # This should pass without errors and set the model category to LLM.
            "supported_architecture",
            new_model(
                1,
                "test_name",
                1,
                huggingface_repo_id="Qwen/Qwen2.5-0.5B-Instruct",
                backend=BackendEnum.VLLM,
                backend_parameters=[],
            ),
            None,
            None,
            ["LLM"],
        ),
    ],
)
@pytest.mark.asyncio
async def test_evaluate_pretrained_config(
    config, case_name, model, expect_error, expect_error_match, expect_categories
):
    """Run ``evaluate_pretrained_config`` against one parametrized model case.

    Args:
        config: Fixture requested for its setup side effects; it is not
            referenced in the body (presumably defined in a conftest — confirm).
        case_name: Human-readable case label, prefixed onto assertion failures.
        model: Model built with ``new_model``; on success its ``categories``
            are expected to have been populated by the call.
        expect_error: Exception type the call must raise, or ``None`` when the
            call must succeed.
        expect_error_match: Expected error message text, matched as a literal
            substring of the raised error's message, or ``None`` to accept any
            message.
        expect_categories: ``CategoryEnum`` member names that
            ``model.categories`` must equal after a successful call, or
            ``None`` to skip the category check.
    """
    # `pytest.raises(match=...)` applies `re.search`, so escape the expected
    # message; otherwise characters such as "." act as regex wildcards and the
    # check is looser than intended.
    match = None if expect_error_match is None else re.escape(expect_error_match)
    try:
        if expect_error:
            with pytest.raises(expect_error, match=match):
                await evaluate_pretrained_config(model)
        else:
            await evaluate_pretrained_config(model)
            # Compare against None so an explicit empty expectation ([]) is
            # still asserted instead of being silently skipped.
            if expect_categories is not None:
                assert model.categories == [CategoryEnum[c] for c in expect_categories]
    except AssertionError as e:
        # Prefix the case label so a failure is easy to attribute.
        raise AssertionError(f"Test case '{case_name}' failed: {e}") from e
